{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uxCkB_DXTHzf"
      },
      "outputs": [],
      "source": [
        "# Copyright 2024 Google LLC\n",
        "#\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "#     https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hny4I-ODTIS6"
      },
      "source": [
        "# Create a Photoshop Document with Image Segmentation on Vertex AI\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://colab.research.google.com/github/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://www.gstatic.com/pantheon/images/bigquery/welcome_page/colab-logo.svg\" alt=\"Google Colaboratory logo\"><br> Run in Colab\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/colab/import/https:%2F%2Fraw.githubusercontent.com%2FGoogleCloudPlatform%2Fgenerative-ai%2Fmain%2Fvision%2Fuse-cases%2Fimage_segmentation_layers.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://lh3.googleusercontent.com/JmcxdQi-qOpctIvWKgPtrzZdJJK-J3sWE1RsfjZNwshCFgE_9fULcNpuXYTilIR2hjwN\" alt=\"Google Cloud Colab Enterprise logo\"><br> Run in Colab Enterprise\n",
        "    </a>\n",
        "  </td>\n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://console.cloud.google.com/vertex-ai/workbench/deploy-notebook?download_url=https://raw.githubusercontent.com/GoogleCloudPlatform/generative-ai/main/vision/use-cases/image_segmentation_layers.ipynb\">\n",
        "      <img src=\"https://lh3.googleusercontent.com/UiNooY4LUgW_oTvpsNhPpQzsstV5W8F7rYgxgGBD85cWJoLmrOzhVs_ksK_vgx40SHs7jCqkTkCk=e14-rj-sc0xffffff-h130-w32\" alt=\"Vertex AI logo\"><br> Open in Vertex AI Workbench\n",
        "    </a>\n",
        "  </td>    \n",
        "  <td style=\"text-align: center\">\n",
        "    <a href=\"https://github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\">\n",
        "      <img width=\"32px\" src=\"https://raw.githubusercontent.com/primer/octicons/refs/heads/main/icons/mark-github-24.svg\" alt=\"GitHub logo\"><br> View on GitHub\n",
        "    </a>\n",
        "  </td>\n",
        "</table>\n",
        "\n",
        "<div style=\"clear: both;\"></div>\n",
        "\n",
        "<b>Share to:</b>\n",
        "\n",
        "<a href=\"https://www.linkedin.com/sharing/share-offsite/?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/8/81/LinkedIn_icon.svg\" alt=\"LinkedIn logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://bsky.app/intent/compose?text=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/7/7a/Bluesky_Logo.svg\" alt=\"Bluesky logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://twitter.com/intent/tweet?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/5a/X_icon_2.svg\" alt=\"X logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://reddit.com/submit?url=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://redditinc.com/hubfs/Reddit%20Inc/Brand/Reddit_Logo.png\" alt=\"Reddit logo\">\n",
        "</a>\n",
        "\n",
        "<a href=\"https://www.facebook.com/sharer/sharer.php?u=https%3A//github.com/GoogleCloudPlatform/generative-ai/blob/main/vision/use-cases/image_segmentation_layers.ipynb\" target=\"_blank\">\n",
        "  <img width=\"20px\" src=\"https://upload.wikimedia.org/wikipedia/commons/5/51/Facebook_f_logo_%282019%29.svg\" alt=\"Facebook logo\">\n",
        "</a>            "
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "84f0f73a0f76"
      },
      "source": [
        "| | |\n",
        "|-|-|\n",
        "|Author(s) | [Katie Nguyen](https://github.com/katiemn) |"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-nLS57E2TO5y"
      },
      "source": [
        "## Overview\n",
        "\n",
        "### Image Segmentation\n",
        "\n",
        "Image Segmentation on Vertex AI brings Google's state-of-the-art segmentation models to developers as a scalable and reliable service.\n",
        "\n",
        "With the Vertex AI Image Segmentation API, developers can choose from five different modes to segment images and build AI products.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Gb8FlC1uULsE"
      },
      "source": [
        "In this tutorial, you will learn how to use the Vertex AI API to interact with the Image Segmentation model to:\n",
        "\n",
        "- Segment images using different modes to create image masks\n",
        "- Turn those image masks into individual layers in a PSD file\n",
        "- Save the PSD file to a Cloud Storage bucket"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mvKl-BtQTRiQ"
      },
      "source": [
        "## Get Started"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "l2IULNJBtaGS"
      },
      "source": [
        "### Install Vertex AI SDK for Python and Wand"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PPmd1BWXthqG"
      },
      "outputs": [],
      "source": [
        "# Wand is a Python binding for ImageMagick, which is provided by this system package.\n",
        "# `-y` answers apt's confirmation prompt, which cannot be answered from a notebook cell.\n",
        "!sudo apt-get install -y libmagickwand-dev\n",
        "\n",
        "%pip install --upgrade --user --quiet google-cloud-aiplatform Wand"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ozjbBiWUuXeG"
      },
      "source": [
        "### Restart runtime\n",
        "\n",
        "To use the newly installed packages in this Jupyter runtime, you must restart the runtime. You can do this by running the cell below, which restarts the current kernel.\n",
        "\n",
        "The restart might take a minute or longer. After it's restarted, continue to the next step."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-sW9xQiMufrZ"
      },
      "outputs": [],
      "source": [
        "import IPython\n",
        "\n",
        "# Shut the kernel down with restart=True so the newly installed packages can be imported.\n",
        "IPython.Application.instance().kernel.do_shutdown(True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2N9y_BrHUt_I"
      },
      "source": [
        "<div class=\"alert alert-block alert-warning\">\n",
        "<b>⚠️ The kernel is going to restart. Please wait until it is finished before continuing to the next step. ⚠️</b>\n",
        "</div>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "opUxT_k5TdgP"
      },
      "source": [
        "### Authenticate your notebook environment (Colab only)\n",
        "\n",
        "If you are running this notebook on Google Colab, run the following cell to authenticate your environment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "vbNgv4q1T2Mi"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "\n",
        "# The presence of `google.colab` among the loaded modules is used to detect Colab.\n",
        "running_in_colab = \"google.colab\" in sys.modules\n",
        "\n",
        "if running_in_colab:\n",
        "    from google.colab import auth\n",
        "\n",
        "    auth.authenticate_user()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ybBXSukZkgjg"
      },
      "source": [
        "### Set Google Cloud project information and initialize Vertex AI API\n",
        "\n",
        "To get started using Vertex AI, you must have an existing Google Cloud project and [enable the Vertex AI API](https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com).\n",
        "\n",
        "Learn more about [setting up a project and a development environment](https://cloud.google.com/vertex-ai/docs/start/cloud-environment)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "q7YvbXXdtzDT"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "from google.cloud import aiplatform\n",
        "\n",
        "PROJECT_ID = \"[your-project-id]\"  # @param {type:\"string\"}\n",
        "LOCATION = \"us-central1\"  # @param {type:\"string\"}\n",
        "\n",
        "# Fall back to the environment when the placeholder was left untouched, and stop\n",
        "# here with a clear message instead of failing later on the first request.\n",
        "if not PROJECT_ID or PROJECT_ID == \"[your-project-id]\":\n",
        "    PROJECT_ID = os.environ.get(\"GOOGLE_CLOUD_PROJECT\", \"\")\n",
        "if not PROJECT_ID:\n",
        "    raise ValueError(\"Set PROJECT_ID to your Google Cloud project ID.\")\n",
        "\n",
        "aiplatform.init(project=PROJECT_ID, location=LOCATION)\n",
        "\n",
        "# Prediction requests are sent to the regional API endpoint for LOCATION\n",
        "api_regional_endpoint = f\"{LOCATION}-aiplatform.googleapis.com\"\n",
        "client_options = {\"api_endpoint\": api_regional_endpoint}\n",
        "client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)\n",
        "\n",
        "# Resource name of the Image Segmentation publisher model used by every request\n",
        "model_endpoint = f\"projects/{PROJECT_ID}/locations/{LOCATION}/publishers/google/models/image-segmentation-001\""
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6ncLgoOYVl-b"
      },
      "source": [
        "### Import libraries"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "bE5g6hrVWNw3"
      },
      "outputs": [],
      "source": [
        "import base64\n",
        "import io\n",
        "import math\n",
        "import os\n",
        "from random import randrange\n",
        "import re\n",
        "import typing\n",
        "\n",
        "import IPython\n",
        "from PIL import Image as PIL_Image\n",
        "from PIL import ImageOps as PIL_ImageOps\n",
        "from google.cloud import storage\n",
        "import ipywidgets as widgets\n",
        "import matplotlib.pyplot as plt\n",
        "from vertexai.preview.vision_models import Image as VertexAI_Image\n",
        "from wand.image import Image as Wand_Image"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "sWK6SPaAZ5hu"
      },
      "source": [
        "### Define helper functions"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "Ju_PctW22NUl"
      },
      "outputs": [],
      "source": [
        "# Parses the mask bytes from the response and converts it to an Image PIL object\n",
        "def prediction_to_mask_pil(prediction) -> PIL_Image.Image:\n",
        "    \"\"\"Decode the base64-encoded mask of one prediction into a PIL image.\n",
        "\n",
        "    Args:\n",
        "        prediction: One entry of the response's predictions; its\n",
        "            \"bytesBase64Encoded\" value holds the encoded mask image.\n",
        "\n",
        "    Returns:\n",
        "        The mask as a PIL image, reduced in place by `thumbnail` so that it\n",
        "        fits within 4096x4096 pixels.\n",
        "    \"\"\"\n",
        "    encoded_mask_string = prediction[\"bytesBase64Encoded\"]\n",
        "    mask_bytes = base64.b64decode(encoded_mask_string)\n",
        "    mask_pil = PIL_Image.open(io.BytesIO(mask_bytes))\n",
        "    mask_pil.thumbnail((4096, 4096))\n",
        "    return mask_pil\n",
        "\n",
        "\n",
        "# Displays a PIL image horizontally next to a generated mask from the response\n",
        "def display_horizontally(input_images: list, figsize: tuple[int, int] = (20, 5)):\n",
        "    \"\"\"Show images side by side in a grid of at most four columns.\n",
        "\n",
        "    Args:\n",
        "        input_images: Images accepted by `Axes.imshow`. The first one is drawn\n",
        "            without an explicit colormap; every following one is drawn with\n",
        "            the \"gray\" colormap.\n",
        "        figsize: Size of the whole figure, passed to `plt.subplots`.\n",
        "    \"\"\"\n",
        "    if not input_images:\n",
        "        # Nothing to draw; `plt.subplots` cannot build a grid with zero rows.\n",
        "        return\n",
        "\n",
        "    max_cols = 4  # Display at most 4 images per row\n",
        "    rows = math.ceil(len(input_images) / max_cols)\n",
        "    cols = min(len(input_images), max_cols)\n",
        "    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so `.flat` works.\n",
        "    _, axes = plt.subplots(nrows=rows, ncols=cols, figsize=figsize, squeeze=False)\n",
        "\n",
        "    for i, ax in enumerate(axes.flat):\n",
        "        # Hide frame and ticks on every cell, including unused trailing ones.\n",
        "        ax.axis(\"off\")\n",
        "        if i < len(input_images):\n",
        "            ax.imshow(input_images[i], cmap=\"gray\" if i > 0 else None)\n",
        "            # Adjust the axis aspect ratio to maintain image proportions\n",
        "            ax.set_aspect(\"equal\")\n",
        "\n",
        "    # Adjust the layout to minimize whitespace between subplots.\n",
        "    plt.tight_layout()\n",
        "    plt.show()\n",
        "\n",
        "\n",
        "def display_image(\n",
        "    image: VertexAI_Image,\n",
        "    max_width: int = 4096,\n",
        "    max_height: int = 4096,\n",
        ") -> None:\n",
        "    \"\"\"Show a Vertex AI image inline, shrunk to fit `max_width` x `max_height`.\"\"\"\n",
        "    picture = typing.cast(PIL_Image.Image, image._pil_image)\n",
        "\n",
        "    # RGB is supported by all Jupyter environments (e.g. RGBA is not yet)\n",
        "    needs_rgb_conversion = picture.mode != \"RGB\"\n",
        "    if needs_rgb_conversion:\n",
        "        picture = picture.convert(\"RGB\")\n",
        "\n",
        "    width, height = picture.size\n",
        "    exceeds_bounds = width > max_width or height > max_height\n",
        "    if exceeds_bounds:\n",
        "        # Resize to display a smaller notebook image\n",
        "        picture = PIL_ImageOps.contain(picture, (max_width, max_height))\n",
        "\n",
        "    IPython.display.display(picture)\n",
        "\n",
        "\n",
        "# Constructs a Vertex AI PredictRequest for the Image Segmentation model\n",
        "def call_vertex_image_segmentation(\n",
        "    gcs_uri=None,\n",
        "    mode=\"foreground\",\n",
        "    prompt=None,\n",
        "):\n",
        "    \"\"\"Send one image to the Image Segmentation model and return the response.\n",
        "\n",
        "    Args:\n",
        "        gcs_uri: Cloud Storage URI of the image to segment. Required.\n",
        "        mode: Segmentation mode, sent as the `mode` request parameter.\n",
        "        prompt: Optional text sent along with the image; left out when empty.\n",
        "\n",
        "    Returns:\n",
        "        The response returned by `client.predict`.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If `gcs_uri` is empty. Without this check a prompt without\n",
        "            an image failed with an IndexError and a call without an image\n",
        "            sent a request that contained no instance.\n",
        "    \"\"\"\n",
        "    if not gcs_uri:\n",
        "        raise ValueError(\"gcs_uri is required: pass the URI of the image to segment.\")\n",
        "\n",
        "    instance = {\"image\": {\"gcsUri\": gcs_uri}}\n",
        "    if prompt:\n",
        "        instance[\"prompt\"] = prompt\n",
        "\n",
        "    parameters = {\"mode\": mode}\n",
        "    response = client.predict(\n",
        "        endpoint=model_endpoint, instances=[instance], parameters=parameters\n",
        "    )\n",
        "\n",
        "    return response"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R45VRKWInfQQ"
      },
      "source": [
        "### Select an image to segment from a Google Cloud Storage URI"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "whY3dM8XEGgG"
      },
      "outputs": [],
      "source": [
        "file_path = \"gs://\"  # @param {type:\"string\"}\n",
        "\n",
        "# Load the image file as Image object and keep a local copy named original.png,\n",
        "# which the later cells open again\n",
        "image_file = VertexAI_Image.load_from_file(file_path)\n",
        "image_file.save(\"original.png\")\n",
        "\n",
        "display_image(image_file)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fU32286ooc8Q"
      },
      "source": [
        "### Segment images using different modes\n",
        "\n",
        "You can generate image masks with different Image Segmentation features by setting the `mode` field. This tutorial uses the following four modes:\n",
        "* **Foreground**: Generate a mask of the segmented foreground of the image.\n",
        "* **Background**: Generate a mask of the segmented background of the image.\n",
        "* **Semantic**: Select the items in an image to segment from a set of 194 classes.\n",
        "* **Prompt**: Use an open-vocabulary text prompt to guide the image segmentation.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mBLJtICO8iMQ"
      },
      "source": [
        "### Foreground segmentation request"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c9N8l0oo_cWs"
      },
      "outputs": [],
      "source": [
        "# No prompt is sent here; a prompt only guides the `semantic` and `prompt` modes\n",
        "response = call_vertex_image_segmentation(\n",
        "    gcs_uri=file_path, mode=\"foreground\", prompt=None\n",
        ")\n",
        "\n",
        "foreground_mask = prediction_to_mask_pil(response.predictions[0])\n",
        "foreground_mask.save(\"foreground.png\")\n",
        "\n",
        "# Show the source image next to the mask that was just generated\n",
        "original_image = PIL_Image.open(\"original.png\")\n",
        "display_horizontally([original_image, foreground_mask])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "--7rofOb95hT"
      },
      "source": [
        "### Background segmentation request"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JVtC3lFAGoAu"
      },
      "outputs": [],
      "source": [
        "# No prompt is sent here; a prompt only guides the `semantic` and `prompt` modes\n",
        "response = call_vertex_image_segmentation(\n",
        "    gcs_uri=file_path, mode=\"background\", prompt=None\n",
        ")\n",
        "\n",
        "background_mask = prediction_to_mask_pil(response.predictions[0])\n",
        "background_mask.save(\"background.png\")\n",
        "\n",
        "# Show the source image next to the mask that was just generated\n",
        "original_image = PIL_Image.open(\"original.png\")\n",
        "display_horizontally([original_image, background_mask])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "U9pfcnNsGtcv"
      },
      "source": [
        "### Semantic segmentation request\n",
        "\n",
        "Specify the objects to segment from the set of 194 classes. For your convenience, the classes have been arranged into seven separate categories. Run the cell below and select your classes. To select multiple options from the same category, hold Ctrl (or Command on macOS) while clicking your selections."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RRt9J4cS3qnt"
      },
      "outputs": [],
      "source": [
        "from IPython.display import display\n",
        "\n",
        "\n",
        "def make_class_selector(description, options):\n",
        "    \"\"\"Build the multi-select list for one category of segmentation classes.\n",
        "\n",
        "    Every category uses the same widget settings: nothing preselected, five\n",
        "    visible rows and an enabled control. Only the label and the options differ.\n",
        "    \"\"\"\n",
        "    return widgets.SelectMultiple(\n",
        "        options=options,\n",
        "        value=[],\n",
        "        rows=5,\n",
        "        description=description,\n",
        "        disabled=False,\n",
        "    )\n",
        "\n",
        "\n",
        "home_and_furniture = make_class_selector(\n",
        "    \"Home\",\n",
        "    [\n",
        "        \"oven\",\n",
        "        \"toaster\",\n",
        "        \"ottoman\",\n",
        "        \"sink\",\n",
        "        \"wardrobe\",\n",
        "        \"refrigerator\",\n",
        "        \"chest_of_drawers\",\n",
        "        \"dishwasher\",\n",
        "        \"bookshelf\",\n",
        "        \"armchair\",\n",
        "        \"toilet\",\n",
        "        \"counter_other\",\n",
        "        \"bathtub\",\n",
        "        \"bathroom_counter\",\n",
        "        \"shower\",\n",
        "        \"kitchen_island\",\n",
        "        \"hair_dryer\",\n",
        "        \"door\",\n",
        "        \"couch\",\n",
        "        \"toothbrush\",\n",
        "        \"light_other\",\n",
        "        \"lamp\",\n",
        "        \"sconce\",\n",
        "        \"nightstand\",\n",
        "        \"microwave\",\n",
        "        \"bed\",\n",
        "        \"ceiling\",\n",
        "        \"mirror\",\n",
        "        \"cup\",\n",
        "        \"shelf\",\n",
        "        \"knife\",\n",
        "        \"stairs\",\n",
        "        \"fork\",\n",
        "        \"spoon\",\n",
        "        \"curtain_other\",\n",
        "        \"cabinet\",\n",
        "        \"bowl\",\n",
        "        \"television\",\n",
        "        \"fireplace\",\n",
        "        \"tray\",\n",
        "        \"floor\",\n",
        "        \"stove\",\n",
        "        \"range_hood\",\n",
        "        \"towel\",\n",
        "        \"plate\",\n",
        "        \"rug_floormat\",\n",
        "        \"wall\",\n",
        "        \"window\",\n",
        "        \"washer_dryer\",\n",
        "    ],\n",
        ")\n",
        "\n",
        "food = make_class_selector(\n",
        "    \"Food\",\n",
        "    [\n",
        "        \"broccoli\",\n",
        "        \"carrot\",\n",
        "        \"hot_dog\",\n",
        "        \"pizza\",\n",
        "        \"donut\",\n",
        "        \"cake\",\n",
        "        \"fruit_other\",\n",
        "        \"food_other\",\n",
        "        \"bottle\",\n",
        "        \"wine_glass\",\n",
        "        \"banana\",\n",
        "        \"apple\",\n",
        "        \"sandwich\",\n",
        "        \"orange\",\n",
        "    ],\n",
        ")\n",
        "\n",
        "outdoor_and_recreation = make_class_selector(\n",
        "    \"Outdoor\",\n",
        "    [\n",
        "        \"road\",\n",
        "        \"mountain_hill\",\n",
        "        \"snow\",\n",
        "        \"rock\",\n",
        "        \"sidewalk_pavement\",\n",
        "        \"frisbee\",\n",
        "        \"runway\",\n",
        "        \"skis\",\n",
        "        \"terrain\",\n",
        "        \"snowboard\",\n",
        "        \"sports_ball\",\n",
        "        \"baseball_bat\",\n",
        "        \"baseball_glove\",\n",
        "        \"skateboard\",\n",
        "        \"surfboard\",\n",
        "        \"tennis_racket\",\n",
        "        \"net\",\n",
        "        \"tunnel\",\n",
        "        \"bridge\",\n",
        "        \"tent\",\n",
        "        \"awning\",\n",
        "        \"river_lake\",\n",
        "        \"sea\",\n",
        "        \"bus\",\n",
        "        \"bench\",\n",
        "        \"train\",\n",
        "        \"bike_rack\",\n",
        "        \"vegetation\",\n",
        "        \"truck\",\n",
        "        \"waterfall\",\n",
        "        \"bicycle\",\n",
        "        \"trailer\",\n",
        "        \"sky\",\n",
        "        \"car\",\n",
        "        \"traffic_sign\",\n",
        "        \"boat_ship\",\n",
        "        \"autorickshaw\",\n",
        "        \"traffic_light\",\n",
        "        \"motorcycle\",\n",
        "        \"airplane\",\n",
        "    ],\n",
        ")\n",
        "\n",
        "office_and_work = make_class_selector(\n",
        "    \"Office\",\n",
        "    [\n",
        "        \"storage_tank\",\n",
        "        \"desk\",\n",
        "        \"conveyor_belt\",\n",
        "        \"suitcase\",\n",
        "        \"chair_other\",\n",
        "        \"swivel_chair\",\n",
        "        \"laptop\",\n",
        "        \"whiteboard\",\n",
        "        \"keyboard\",\n",
        "        \"mouse\",\n",
        "    ],\n",
        ")\n",
        "\n",
        "clothing_and_accessories = make_class_selector(\n",
        "    \"Clothing\",\n",
        "    [\"backpack\", \"bag\", \"tie\", \"apparel\"],\n",
        ")\n",
        "\n",
        "animals = make_class_selector(\n",
        "    \"Animals\",\n",
        "    [\n",
        "        \"bird\",\n",
        "        \"cat\",\n",
        "        \"dog\",\n",
        "        \"horse\",\n",
        "        \"sheep\",\n",
        "        \"cow\",\n",
        "        \"elephant\",\n",
        "        \"bear\",\n",
        "        \"zebra\",\n",
        "        \"giraffe\",\n",
        "        \"animal_other\",\n",
        "    ],\n",
        ")\n",
        "\n",
        "miscellaneous = make_class_selector(\n",
        "    \"Miscellaneous\",\n",
        "    [\n",
        "        \"pool_table\",\n",
        "        \"umbrella\",\n",
        "        \"barrel\",\n",
        "        \"case\",\n",
        "        \"book\",\n",
        "        \"crib\",\n",
        "        \"box\",\n",
        "        \"kite\",\n",
        "        \"basket\",\n",
        "        \"clock\",\n",
        "        \"fan\",\n",
        "        \"vase\",\n",
        "        \"scissors\",\n",
        "        \"plaything_other\",\n",
        "        \"stool\",\n",
        "        \"teddy_bear\",\n",
        "        \"seat\",\n",
        "        \"base\",\n",
        "        \"trash_can\",\n",
        "        \"painting\",\n",
        "        \"sculpture\",\n",
        "        \"pier_wharf\",\n",
        "        \"potted_plant\",\n",
        "        \"poster\",\n",
        "        \"column\",\n",
        "        \"bulletin_board\",\n",
        "        \"fountain\",\n",
        "        \"building\",\n",
        "        \"chandelier\",\n",
        "        \"radiator\",\n",
        "        \"table\",\n",
        "        \"stage\",\n",
        "        \"arcade_machine\",\n",
        "        \"banner\",\n",
        "        \"gravel\",\n",
        "        \"flag\",\n",
        "        \"platform\",\n",
        "        \"blanket\",\n",
        "        \"remote\",\n",
        "        \"escalator\",\n",
        "        \"playingfield\",\n",
        "        # NOTE(review): the only entry written with a space instead of an\n",
        "        # underscore - confirm the spelling against the model's class list.\n",
        "        \"cell phone\",\n",
        "        \"railroad\",\n",
        "        \"shower_curtain\",\n",
        "        \"fire_hydrant\",\n",
        "        \"pillow\",\n",
        "        \"parking_meter\",\n",
        "        \"road_barrier\",\n",
        "        \"water_other\",\n",
        "        \"mailbox\",\n",
        "        \"swimming_pool\",\n",
        "        \"person\",\n",
        "        \"cctv_camera\",\n",
        "        \"billboard\",\n",
        "        \"rider_other\",\n",
        "        \"junction_box\",\n",
        "        \"bicyclist\",\n",
        "        \"pole\",\n",
        "        \"motorcyclist\",\n",
        "        \"slow_wheeled_object\",\n",
        "        \"fence\",\n",
        "        \"window_blind\",\n",
        "        \"paper\",\n",
        "        \"streetlight\",\n",
        "        \"railing_banister\",\n",
        "        \"guard_rail\",\n",
        "    ],\n",
        ")\n",
        "\n",
        "# Render the seven category lists one below the other\n",
        "for selector in (\n",
        "    home_and_furniture,\n",
        "    food,\n",
        "    outdoor_and_recreation,\n",
        "    office_and_work,\n",
        "    clothing_and_accessories,\n",
        "    animals,\n",
        "    miscellaneous,\n",
        "):\n",
        "    display(selector)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m3PX4mOOecgf"
      },
      "source": [
        "Combine all your segmentation class selections into a single string for the request."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NxqwxEef_sej"
      },
      "outputs": [],
      "source": [
        "# Collect the selections of all seven category lists, in the order the lists are shown\n",
        "category_selectors = (\n",
        "    home_and_furniture,\n",
        "    food,\n",
        "    outdoor_and_recreation,\n",
        "    office_and_work,\n",
        "    clothing_and_accessories,\n",
        "    animals,\n",
        "    miscellaneous,\n",
        ")\n",
        "item_string = \",\".join(\n",
        "    class_name for selector in category_selectors for class_name in selector.value\n",
        ")\n",
        "print(item_string)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BVDUNa7UevoN"
      },
      "source": [
        "Regardless of the number of classes, a semantic segmentation request will return a single image mask with all detected items from the request."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Aar3Pn3yG75T"
      },
      "outputs": [],
      "source": [
        "# Semantic mode receives the comma-separated class selection as its prompt\n",
        "response = call_vertex_image_segmentation(\n",
        "    gcs_uri=file_path, mode=\"semantic\", prompt=item_string\n",
        ")\n",
        "\n",
        "semantic_mask = prediction_to_mask_pil(response.predictions[0])\n",
        "semantic_mask.save(\"semantic.png\")\n",
        "\n",
        "# Show the source image next to the mask that was just generated\n",
        "original_image = PIL_Image.open(\"original.png\")\n",
        "display_horizontally([original_image, semantic_mask])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CO4q2sacIydg"
      },
      "source": [
        "### Open vocabulary segmentation request\n",
        "\n",
        "Provide a prompt to guide the image segmentation. Unlike other modes, an open vocabulary request will produce multiple image masks based on the prompt."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "unrHSqhqHmlI"
      },
      "outputs": [],
      "source": [
        "# Delete local prompt based masks from previous runs.\n",
        "# Match only the file names this cell writes (prompt<index>.png): as a regular\n",
        "# expression, \"prompt*\" matches every file whose name starts with \"promp\".\n",
        "pattern = re.compile(r\"prompt\\d+\\.png\")\n",
        "for file in os.listdir(\".\"):\n",
        "    if pattern.fullmatch(file):\n",
        "        os.remove(file)\n",
        "\n",
        "gcs_uri = file_path\n",
        "mode = \"prompt\"\n",
        "prompt = \"[your-prompt]\"  # @param {type:\"string\"}\n",
        "\n",
        "response = call_vertex_image_segmentation(gcs_uri, mode, prompt)\n",
        "\n",
        "BACKGROUND_PIL = PIL_Image.open(\"original.png\")\n",
        "images = [BACKGROUND_PIL]\n",
        "# One mask file is saved per prediction, numbered in the order of the response\n",
        "for i, prediction in enumerate(response.predictions):\n",
        "    MASK_PIL = prediction_to_mask_pil(prediction)\n",
        "    MASK_PIL.save(f\"prompt{i}.png\")\n",
        "    images.append(MASK_PIL)\n",
        "\n",
        "display_horizontally(images)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "UrRBCQ4lf78V"
      },
      "source": [
        "### Select masks to apply to PSD file\n",
        "\n",
        "Run the following cell to generate a checklist of the segmentation masks you may have generated in the previous steps. Then select every mask you would like to include in the final PSD file, and clear the checkbox of any mode you did not run, because the next step reads the saved mask file of each selected mode. Each selected image mask is added as a separate layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FPMmGO5TFLYG"
      },
      "outputs": [],
      "source": [
        "from IPython.display import display\n",
        "\n",
        "\n",
        "def make_mask_checkbox(description):\n",
        "    \"\"\"Build an enabled checkbox that starts out checked.\"\"\"\n",
        "    return widgets.Checkbox(value=True, description=description, disabled=False)\n",
        "\n",
        "\n",
        "foreground_checkbox = make_mask_checkbox(\"Foreground Mask\")\n",
        "background_checkbox = make_mask_checkbox(\"Background Mask\")\n",
        "semantic_checkbox = make_mask_checkbox(\"Semantic Mask\")\n",
        "prompt_checkbox = make_mask_checkbox(\"Prompt Mask\")\n",
        "\n",
        "# Render one checkbox per segmentation mode\n",
        "for checkbox in (\n",
        "    foreground_checkbox,\n",
        "    background_checkbox,\n",
        "    semantic_checkbox,\n",
        "    prompt_checkbox,\n",
        "):\n",
        "    display(checkbox)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "rFAh_A01gc_G"
      },
      "source": [
        "### Add selected mask images as layers\n",
        "\n",
        "Once the layers are added, you will save the final PSD file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "VXCM_QOz3Uln"
      },
      "outputs": [],
      "source": [
        "with Wand_Image(filename=\"original.png\") as img:\n",
        "    # NOTE(review): the original is read a second time, presumably because the\n",
        "    # first image of a multi-image PSD serves as the flattened preview and the\n",
        "    # layers start at the second image - confirm before removing this line.\n",
        "    img.read(filename=\"original.png\")\n",
        "\n",
        "    # Each selected mask is appended as one more image of the sequence\n",
        "    if foreground_checkbox.value:\n",
        "        img.read(filename=\"foreground.png\")\n",
        "    if background_checkbox.value:\n",
        "        img.read(filename=\"background.png\")\n",
        "    if semantic_checkbox.value:\n",
        "        img.read(filename=\"semantic.png\")\n",
        "    if prompt_checkbox.value:\n",
        "        # Pick up only the masks saved as prompt<index>.png: as a regular\n",
        "        # expression, \"prompt*\" matches every file whose name starts with \"promp\".\n",
        "        pattern = re.compile(r\"prompt(\\d+)\\.png\")\n",
        "        prompt_masks = [m for m in map(pattern.fullmatch, os.listdir(\".\")) if m]\n",
        "        # os.listdir returns names in arbitrary order; add the layers by mask index.\n",
        "        for mask_match in sorted(prompt_masks, key=lambda m: int(m.group(1))):\n",
        "            img.read(filename=mask_match.group(0))\n",
        "\n",
        "    img.save(filename=\"output.psd\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C--fF0Wjg8Dp"
      },
      "source": [
        "### Upload the PSD file to a Google Cloud Storage bucket"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LW4VpNDsemLm"
      },
      "outputs": [],
      "source": [
        "# The random number makes a clash with an earlier upload unlikely; the extension\n",
        "# keeps the object recognizable as a Photoshop document after download.\n",
        "blob_name = f\"psd_{randrange(10000, 100000)}.psd\"\n",
        "bucket_name = \"[your-bucket-name]\"  # @param {type:\"string\"}\n",
        "\n",
        "# Pass the project explicitly: without it the client has to infer a project from\n",
        "# the environment, which fails when none is configured there.\n",
        "storage_client = storage.Client(project=PROJECT_ID)\n",
        "bucket = storage_client.bucket(bucket_name)\n",
        "blob = bucket.blob(blob_name)\n",
        "\n",
        "blob.upload_from_filename(\"output.psd\")\n",
        "print(f\"Uploaded gs://{bucket_name}/{blob_name}\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "image_segmentation_layers.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
